#include "sot_trt_cuda_impl.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>

#define SV_MODEL_DIR "/SpireCV/models/"
#define SV_ROOT_DIR "/SpireCV/"

#ifdef WITH_CUDA
#include "yolov7/logging.h"
#define TRTCHECK(status)                                 \
  do                                                     \
  {                                                      \
    auto ret = (status);                                 \
    if (ret != 0)                                        \
    {                                                    \
      std::cerr << "Cuda failure: " << ret << std::endl; \
      abort();                                           \
    }                                                    \
  } while (0)

#define DEVICE 0 // GPU id
#define BATCH_SIZE 1
#define MAX_IMAGE_INPUT_SIZE_THRESH 3000 * 3000 // ensure it exceed the maximum size in the input images !
#endif

namespace sv
{

  // File-wide conveniences: std and TensorRT (nvinfer1) names are used unqualified below.
  // NOTE(review): these reference TensorRT names while the "yolov7/logging.h" include above
  // is guarded by WITH_CUDA — confirm this file is only ever compiled with WITH_CUDA defined.
  using namespace std;
  using namespace nvinfer1;
  // Logger passed to every nvinfer1::createInferRuntime call in this file.
  static Logger g_nvlogger;

  /// Fast approximation of e^x (Schraudolph's method; a few percent relative error).
  ///
  /// Builds the IEEE-754 bit pattern 2^23 * (x / ln 2 + bias) directly. The value is clamped
  /// to [0, 0x7F800000] before the integer conversion, so very negative inputs (and NaN) give
  /// +0.0f and very large inputs give +inf. Without the clamp, the float -> uint32_t
  /// conversion would be out of range, which is undefined behaviour. The bits are
  /// reinterpreted with std::memcpy, which, unlike reading an inactive union member, is
  /// well-defined C++.
  inline float fast_exp(float x)
  {
    // 1.4426950409 = 1 / ln(2); 126.93490512 = exponent bias 127 minus a small correction term.
    double bits = (1 << 23) * (1.4426950409 * x + 126.93490512f);
    if (!(bits > 0.0)) // also true for NaN
    {
      bits = 0.0; // bit pattern of +0.0f
    }
    else if (bits > 2139095040.0) // 0x7F800000 == bit pattern of +inf
    {
      bits = 2139095040.0;
    }
    const uint32_t pattern = static_cast<uint32_t>(bits);
    float result;
    std::memcpy(&result, &pattern, sizeof(result));
    return result;
  }

  // Logistic function 1 / (1 + e^-x), using the fast_exp approximation for e^-x.
  inline float sigmoid(float x)
  {
    const float e_neg_x = fast_exp(-x);
    const float denom = 1.0f + e_neg_x;
    return 1.0f / denom;
  }

  static void Softmax(const Mat &src, Mat &dst)
  {
    Mat maxVal;
    cv::max(src.row(1), src.row(0), maxVal);

    src.row(1) -= maxVal;
    src.row(0) -= maxVal;

    exp(src, dst);

    Mat sumVal = dst.row(0) + dst.row(1);
    dst.row(0) = dst.row(0) / sumVal;
    dst.row(1) = dst.row(1) / sumVal;
  }

  // Context-padded size of a (w, h) box: sqrt((w + p) * (h + p)) with p = (w + h) / 2.
  static float sz_whFun(cv::Point2f wh)
  {
    const float context_pad = (wh.x + wh.y) * 0.5f;
    const float padded_w = wh.x + context_pad;
    const float padded_h = wh.y + context_pad;
    return std::sqrt(padded_w * padded_h);
  }

  /// Size-change term of the scale penalty, computed element-wise.
  ///
  /// For each i: t = sqrt((w[i] + p) * (h[i] + p)) / sz with p = (w[i] + h[i]) / 2, and the
  /// result is max(t, 1 / t). For positive t this is >= 1, and equals 1 when the padded size
  /// matches sz.
  ///
  /// @param w  candidate widths
  /// @param h  candidate heights, paired element-wise with w
  /// @param sz reference padded size of the current target (see sz_whFun)
  /// @return one value per (w[i], h[i]) pair; length min(w.size(), h.size())
  static std::vector<float> sz_change_fun(const std::vector<float> &w, const std::vector<float> &h, float sz)
  {
    // Taking the inputs by const reference avoids copying both vectors on every frame.
    const size_t n = std::min(w.size(), h.size());
    std::vector<float> changes;
    changes.reserve(n);
    for (size_t i = 0; i < n; ++i)
    {
      const float pad = (w[i] + h[i]) * 0.5f;
      const float t = std::sqrt((w[i] + pad) * (h[i] + pad)) / sz;
      changes.push_back(std::max(t, 1.0f / t));
    }
    return changes;
  }

  /// Aspect-ratio-change term of the scale penalty, computed element-wise.
  ///
  /// For each i: t = (target_sz.x / target_sz.y) / (w[i] / h[i]), and the result is
  /// max(t, 1 / t).
  ///
  /// @param w         candidate widths
  /// @param h         candidate heights, paired element-wise with w
  /// @param target_sz current target size (x = width, y = height)
  /// @return one value per (w[i], h[i]) pair; length min(w.size(), h.size())
  static std::vector<float> ratio_change_fun(const std::vector<float> &w, const std::vector<float> &h, cv::Point2f target_sz)
  {
    // Taking the inputs by const reference avoids copying both vectors on every frame.
    const float ratio = target_sz.x / target_sz.y;
    const size_t n = std::min(w.size(), h.size());
    std::vector<float> changes;
    changes.reserve(n);
    for (size_t i = 0; i < n; ++i)
    {
      const float t = ratio / (w[i] / h[i]);
      changes.push_back(std::max(t, 1.0f / t));
    }
    return changes;
  }

  // Constructor. Nothing is initialised here; the TensorRT contexts, CUDA buffers/streams
  // and host arrays are only set up by cudaSetup().
  SotTrtDetectorCUDAImpl::SotTrtDetectorCUDAImpl()
  {
  }

  // Destructor.
  // NOTE(review): nothing is released here. The cudaMalloc'd device buffers, the CUDA streams,
  // the execution contexts and the new[]'d host arrays created in cudaSetup() are leaked, as
  // are the IRuntime/ICudaEngine objects, which cudaSetup() never stores. Releasing them
  // safely needs the members to be null-initialised in the class declaration (not visible
  // here) — TODO.
  SotTrtDetectorCUDAImpl::~SotTrtDetectorCUDAImpl()
  {
  }

  bool SotTrtDetectorCUDAImpl::cudaSetup()
  {

    std::string trt_model_fn1 = get_home() + SV_MODEL_DIR + "nanotrack_backbone_sim.engine";
    std::string trt_model_fn2 = get_home() + SV_MODEL_DIR + "nanotrack_backbone_temp.engine";
    std::string trt_model_fn3 = get_home() + SV_MODEL_DIR + "nanotrack_head_sim.engine";
    if (!is_file_exist(trt_model_fn1) && !is_file_exist(trt_model_fn2) && !is_file_exist(trt_model_fn3))
    {
      throw std::runtime_error("SpireCV (104) Error loading the Nanotrack TensorRT model (File Not Exist)");
    }
    char *trt_model_stream1{nullptr}, *trt_model_stream2{nullptr}, *trt_model_stream3{nullptr};
    size_t trt_model_size1{0}, trt_model_size2{0}, trt_model_size3{0};
    ;
    try
    {
      std::ifstream file1(trt_model_fn1, std::ios::binary);
      file1.seekg(0, file1.end);
      trt_model_size1 = file1.tellg();
      file1.seekg(0, file1.beg);
      trt_model_stream1 = new char[trt_model_size1];
      assert(trt_model_stream1);
      file1.read(trt_model_stream1, trt_model_size1);
      file1.close();
    }
    catch (const std::runtime_error &e)
    {
      throw std::runtime_error("SpireCV (104) Error loading the TensorRT model!");
    }

    try
    {
      std::ifstream file2(trt_model_fn2, std::ios::binary);
      file2.seekg(0, file2.end);
      trt_model_size2 = file2.tellg();
      file2.seekg(0, file2.beg);
      trt_model_stream2 = new char[trt_model_size2];
      assert(trt_model_stream2);
      file2.read(trt_model_stream2, trt_model_size2);
      file2.close();
    }
    catch (const std::runtime_error &e)
    {
      throw std::runtime_error("SpireCV (104) Error loading the TensorRT model!");
    }

    try
    {
      std::ifstream file3(trt_model_fn3, std::ios::binary);
      file3.seekg(0, file3.end);
      trt_model_size3 = file3.tellg();
      file3.seekg(0, file3.beg);
      trt_model_stream3 = new char[trt_model_size3];
      assert(trt_model_stream3);
      file3.read(trt_model_stream3, trt_model_size3);
      file3.close();
    }
    catch (const std::runtime_error &e)
    {
      throw std::runtime_error("SpireCV (104) Error loading the TensorRT model!");
    }

    // TensorRT1
    IRuntime *runtime_1 = nvinfer1::createInferRuntime(g_nvlogger);
    assert(runtime_1 != nullptr);
    ICudaEngine *p_cu_engine1 = runtime_1->deserializeCudaEngine(trt_model_stream1, trt_model_size1);
    assert(p_cu_engine1 != nullptr);
    this->_trt_context[0] = p_cu_engine1->createExecutionContext();
    assert(this->_trt_context[0] != nullptr);

    delete[] trt_model_stream1;
    const ICudaEngine &cu_engine1 = this->_trt_context[0]->getEngine();
    assert(cu_engine1.getNbBindings() == 2);

    this->_input_index_1 = cu_engine1.getBindingIndex("input");
    this->_output_index_1 = cu_engine1.getBindingIndex("output");
    TRTCHECK(cudaMalloc(&_p_buffers_1[this->_input_index_1], 1 * 3 * 255 * 255 * sizeof(float)));
    TRTCHECK(cudaMalloc(&_p_buffers_1[this->_output_index_1], 1 * 48 * 16 * 16 * sizeof(float)));
    TRTCHECK(cudaStreamCreate(&_cu_stream1));

    auto input_dims_1 = nvinfer1::Dims4{1, 3, 255, 255};
    this->_trt_context[0]->setBindingDimensions(this->_input_index_1, input_dims_1);

    this->_p_data_1 = new float[1 * 3 * 255 * 255];
    this->_p_prob_1 = new float[1 * 48 * 16 * 16];
    // Input
    TRTCHECK(cudaMemcpyAsync(_p_buffers_1[this->_input_index_1], this->_p_data_1, 1 * 3 * 255 * 255 * sizeof(float), cudaMemcpyHostToDevice, this->_cu_stream1));
    // this->_trt_context->enqueue(1, _p_buffers, this->_cu_stream, nullptr);
    this->_trt_context[0]->enqueueV2(_p_buffers_1, this->_cu_stream1, nullptr);
    // Output
    TRTCHECK(cudaMemcpyAsync(this->_p_prob_1, _p_buffers_1[this->_output_index_1], 1 * 48 * 16 * 16 * sizeof(float), cudaMemcpyDeviceToHost, this->_cu_stream1));
    cudaStreamSynchronize(this->_cu_stream1);

    // TensorRT2
    IRuntime *runtime_2 = nvinfer1::createInferRuntime(g_nvlogger);
    assert(runtime_2 != nullptr);
    ICudaEngine *p_cu_engine2 = runtime_2->deserializeCudaEngine(trt_model_stream2, trt_model_size2);
    assert(p_cu_engine2 != nullptr);
    this->_trt_context[1] = p_cu_engine2->createExecutionContext();
    assert(this->_trt_context[1] != nullptr);

    delete[] trt_model_stream2;
    const ICudaEngine &cu_engine2 = this->_trt_context[1]->getEngine();
    assert(cu_engine2.getNbBindings() == 2);

    this->_input_index_2 = cu_engine2.getBindingIndex("input");
    this->_output_index_2 = cu_engine2.getBindingIndex("output");
    TRTCHECK(cudaMalloc(&_p_buffers_2[this->_input_index_2], 1 * 3 * 127 * 127 * sizeof(float)));
    TRTCHECK(cudaMalloc(&_p_buffers_2[this->_output_index_2], 1 * 48 * 8 * 8 * sizeof(float)));
    TRTCHECK(cudaStreamCreate(&_cu_stream2));

    auto input_dims_2 = nvinfer1::Dims4{1, 3, 127, 127};
    this->_trt_context[1]->setBindingDimensions(this->_input_index_2, input_dims_2);

    this->_p_data_2 = new float[1 * 3 * 127 * 127];
    this->_p_prob_2 = new float[1 * 48 * 8 * 8];
    // Input
    TRTCHECK(cudaMemcpyAsync(_p_buffers_2[this->_input_index_2], this->_p_data_2, 1 * 3 * 127 * 127 * sizeof(float), cudaMemcpyHostToDevice, this->_cu_stream2));
    // this->_trt_context->enqueue(1, _p_buffers, this->_cu_stream, nullptr);
    this->_trt_context[1]->enqueueV2(_p_buffers_2, this->_cu_stream2, nullptr);
    // Output
    TRTCHECK(cudaMemcpyAsync(this->_p_prob_2, _p_buffers_2[this->_output_index_2], 1 * 48 * 8 * 8 * sizeof(float), cudaMemcpyDeviceToHost, this->_cu_stream2));
    cudaStreamSynchronize(this->_cu_stream2);

    // TensorRT3
    IRuntime *runtime_3 = nvinfer1::createInferRuntime(g_nvlogger);
    assert(runtime_3 != nullptr);
    ICudaEngine *p_cu_engine3 = runtime_3->deserializeCudaEngine(trt_model_stream3, trt_model_size3);
    assert(p_cu_engine3 != nullptr);
    this->_trt_context[2] = p_cu_engine3->createExecutionContext();
    assert(this->_trt_context[2] != nullptr);

    delete[] trt_model_stream3;
    const ICudaEngine &cu_engine3 = this->_trt_context[2]->getEngine();
    assert(cu_engine3.getNbBindings() == 4);

    this->_input_index_3_1 = cu_engine3.getBindingIndex("input1");
    this->_input_index_3_2 = cu_engine3.getBindingIndex("input2");
    this->_output_index_3_1 = cu_engine3.getBindingIndex("output1");
    this->_output_index_3_2 = cu_engine3.getBindingIndex("output2");
    TRTCHECK(cudaMalloc(&_p_buffers_3[this->_input_index_3_1], 1 * 48 * 8 * 8 * sizeof(float)));
    TRTCHECK(cudaMalloc(&_p_buffers_3[this->_input_index_3_2], 1 * 48 * 16 * 16 * sizeof(float)));
    TRTCHECK(cudaMalloc(&_p_buffers_3[this->_output_index_3_1], 1 * 2 * 16 * 16 * sizeof(float)));
    TRTCHECK(cudaMalloc(&_p_buffers_3[this->_output_index_3_2], 1 * 4 * 16 * 16 * sizeof(float)));
    TRTCHECK(cudaStreamCreate(&_cu_stream3));

    auto input_dims_3_1 = nvinfer1::Dims4{1, 48, 8, 8};
    auto input_dims_3_2 = nvinfer1::Dims4{1, 48, 16, 16};
    this->_trt_context[2]->setBindingDimensions(this->_input_index_3_1, input_dims_3_1);
    this->_trt_context[2]->setBindingDimensions(this->_input_index_3_2, input_dims_3_2);
    this->_p_data_3_1 = new float[1 * 48 * 8 * 8];
    this->_p_data_3_2 = new float[1 * 48 * 16 * 16];
    this->_p_prob_3_1 = new float[1 * 2 * 16 * 16];
    this->_p_prob_3_2 = new float[1 * 4 * 16 * 16];
    // Input
    TRTCHECK(cudaMemcpyAsync(_p_buffers_3[this->_input_index_3_1], this->_p_data_3_1, 1 * 48 * 8 * 8 * sizeof(float), cudaMemcpyHostToDevice, this->_cu_stream3));
    TRTCHECK(cudaMemcpyAsync(_p_buffers_3[this->_input_index_3_2], this->_p_data_3_2, 1 * 48 * 16 * 16 * sizeof(float), cudaMemcpyHostToDevice, this->_cu_stream3));
    // this->_trt_context->enqueue(1, _p_buffers, this->_cu_stream, nullptr);
    this->_trt_context[2]->enqueueV2(_p_buffers_3, this->_cu_stream3, nullptr);
    // Output
    TRTCHECK(cudaMemcpyAsync(this->_p_prob_3_1, _p_buffers_3[this->_output_index_3_1], 1 * 2 * 16 * 16 * sizeof(float), cudaMemcpyDeviceToHost, this->_cu_stream3));
    TRTCHECK(cudaMemcpyAsync(this->_p_prob_3_2, _p_buffers_3[this->_output_index_3_2], 1 * 4 * 16 * 16 * sizeof(float), cudaMemcpyDeviceToHost, this->_cu_stream3));
    cudaStreamSynchronize(this->_cu_stream3);

    return false;
  }

  /// Initialises the tracker on the first frame.
  ///
  /// It:
  ///  - rebuilds the Hann window and the score-map grid;
  ///  - crops a context-padded exemplar around bbox and resizes it to cfg.exemplar_size;
  ///  - runs the template backbone (TensorRT2) on the crop and keeps the resulting template
  ///    features in _p_prob_2, which update() feeds to the head;
  ///  - stores the initial position and size in state.
  ///
  /// @param img  first frame; assumed 8-bit 3-channel BGR (the packing loop reads 3 uchar
  ///             per pixel) — TODO confirm callers guarantee this
  /// @param bbox initial target box in image coordinates
  void SotTrtDetectorCUDAImpl::cudaInitImpl(cv::Mat img, const cv::Rect &bbox)
  {
    create_window();
    create_grids();

    cv::Point target_pos;               // cx, cy
    cv::Point2f target_sz = {0.f, 0.f}; // w, h
    target_pos.x = bbox.x + bbox.width / 2;
    target_pos.y = bbox.y + bbox.height / 2;
    target_sz.x = bbox.width;
    target_sz.y = bbox.height;

    // Side length of the square, context-padded exemplar region.
    float wc_z = target_sz.x + cfg.context_amount * (target_sz.x + target_sz.y);
    float hc_z = target_sz.y + cfg.context_amount * (target_sz.x + target_sz.y);
    float s_z = round(sqrt(wc_z * hc_z));

    // The mean colour is used to fill parts of the crop that fall outside the image.
    cv::Scalar avg_chans = cv::mean(img);
    cv::Mat z_crop = get_subwindow_tracking(img, target_pos, cfg.exemplar_size, int(s_z), avg_chans); // BGR order

    // Pack the 127x127 crop (3 x uchar per pixel) into planar CHW float and normalise each
    // channel: mean = [136.20, 141.50, 145.41], std = [44.77, 44.20, 44.30].
    for (int row = 0; row < 127; ++row)
    {
      const uchar *uc_pixel = z_crop.data + row * z_crop.step;
      for (int col = 0; col < 127; ++col)
      {
        this->_p_data_2[col + row * 127] = ((float)uc_pixel[0] - 136.20f) / 44.77f;
        this->_p_data_2[col + row * 127 + 127 * 127] = ((float)uc_pixel[1] - 141.50f) / 44.20f;
        this->_p_data_2[col + row * 127 + 127 * 127 * 2] = ((float)uc_pixel[2] - 145.41f) / 44.30f;
        uc_pixel += 3;
      }
    }

    // TensorRT2: template backbone, 1x3x127x127 -> 1x48x8x8 template features in _p_prob_2.
    TRTCHECK(cudaMemcpyAsync(_p_buffers_2[this->_input_index_2], this->_p_data_2, 1 * 3 * 127 * 127 * sizeof(float), cudaMemcpyHostToDevice, this->_cu_stream2));
    if (!this->_trt_context[1]->enqueueV2(_p_buffers_2, this->_cu_stream2, nullptr))
      throw std::runtime_error("SpireCV (104) TensorRT inference enqueue failed!");
    TRTCHECK(cudaMemcpyAsync(this->_p_prob_2, _p_buffers_2[this->_output_index_2], 1 * 48 * 8 * 8 * sizeof(float), cudaMemcpyDeviceToHost, this->_cu_stream2));
    TRTCHECK(cudaStreamSynchronize(this->_cu_stream2));

    this->state.channel_ave = avg_chans;
    this->state.im_h = img.rows;
    this->state.im_w = img.cols;
    this->state.target_pos = target_pos;
    this->state.target_sz = target_sz;
  }

  /// Runs one tracking step on a search crop.
  ///
  /// Steps:
  ///  1. Encode the 255x255 crop with the search backbone (TensorRT1).
  ///  2. Run the head (TensorRT3) on the stored template features plus the search features.
  ///  3. Decode one candidate box per 16x16 score cell.
  ///  4. Apply the size/ratio-change and Hann-window penalties and pick the best cell.
  ///  5. Move and resize the target accordingly.
  ///
  /// @param x_crops       search crop; assumed 255x255, 8-bit 3-channel BGR — TODO confirm
  /// @param target_pos    in: current centre (image coords); out: updated centre
  /// @param target_sz     in: current size in search-crop scale; out: updated size in image scale
  /// @param scale_z       image -> crop scale factor
  /// @param cls_score_max out: classification score of the selected cell
  void SotTrtDetectorCUDAImpl::update(const cv::Mat &x_crops, cv::Point &target_pos, cv::Point2f &target_sz, float scale_z, float &cls_score_max)
  {
    const int rows = 16;           // score-map height (head outputs are 16x16)
    const int cols = 16;           // score-map width
    const int plane = rows * cols; // elements in one output channel

    // Pack the crop (3 x uchar per pixel) into planar CHW float and normalise each channel:
    // mean = [136.20, 141.50, 145.41], std = [44.77, 44.20, 44.30].
    for (int row = 0; row < 255; ++row)
    {
      const uchar *uc_pixel = x_crops.data + row * x_crops.step;
      for (int col = 0; col < 255; ++col)
      {
        this->_p_data_1[col + row * 255] = ((float)uc_pixel[0] - 136.20f) / 44.77f;
        this->_p_data_1[col + row * 255 + 255 * 255] = ((float)uc_pixel[1] - 141.50f) / 44.20f;
        this->_p_data_1[col + row * 255 + 255 * 255 * 2] = ((float)uc_pixel[2] - 145.41f) / 44.30f;
        uc_pixel += 3;
      }
    }

    // TensorRT1: search backbone, 1x3x255x255 -> 1x48x16x16 search features in _p_prob_1.
    TRTCHECK(cudaMemcpyAsync(_p_buffers_1[this->_input_index_1], this->_p_data_1, 1 * 3 * 255 * 255 * sizeof(float), cudaMemcpyHostToDevice, this->_cu_stream1));
    if (!this->_trt_context[0]->enqueueV2(_p_buffers_1, this->_cu_stream1, nullptr))
      throw std::runtime_error("SpireCV (104) TensorRT inference enqueue failed!");
    TRTCHECK(cudaMemcpyAsync(this->_p_prob_1, _p_buffers_1[this->_output_index_1], 1 * 48 * 16 * 16 * sizeof(float), cudaMemcpyDeviceToHost, this->_cu_stream1));
    TRTCHECK(cudaStreamSynchronize(this->_cu_stream1));

    // TensorRT3: head, (template features, search features) -> output1 (1x2x16x16), output2 (1x4x16x16).
    TRTCHECK(cudaMemcpyAsync(_p_buffers_3[this->_input_index_3_1], this->_p_prob_2, 1 * 48 * 8 * 8 * sizeof(float), cudaMemcpyHostToDevice, this->_cu_stream3));
    TRTCHECK(cudaMemcpyAsync(_p_buffers_3[this->_input_index_3_2], this->_p_prob_1, 1 * 48 * 16 * 16 * sizeof(float), cudaMemcpyHostToDevice, this->_cu_stream3));
    if (!this->_trt_context[2]->enqueueV2(_p_buffers_3, this->_cu_stream3, nullptr))
      throw std::runtime_error("SpireCV (104) TensorRT inference enqueue failed!");
    TRTCHECK(cudaMemcpyAsync(this->_p_prob_3_1, _p_buffers_3[this->_output_index_3_1], 1 * 2 * 16 * 16 * sizeof(float), cudaMemcpyDeviceToHost, this->_cu_stream3));
    TRTCHECK(cudaMemcpyAsync(this->_p_prob_3_2, _p_buffers_3[this->_output_index_3_2], 1 * 4 * 16 * 16 * sizeof(float), cudaMemcpyDeviceToHost, this->_cu_stream3));
    TRTCHECK(cudaStreamSynchronize(this->_cu_stream3));

    // Per-cell score: sigmoid of the first plane of output1.
    // NOTE(review): output1 has two planes, but only plane 0 is used, through a sigmoid. If
    // the head emits 2-class logits, softmax(...)[1] may be what is intended — confirm
    // against the exported model.
    std::vector<float> cls_score_sigmoid(plane);
    for (int i = 0; i < plane; ++i)
      cls_score_sigmoid[i] = sigmoid(this->_p_prob_3_1[i]);

    // output2 is a linear NCHW buffer, so plane k starts at offset k * plane. Planes 0..3 are
    // subtracted from / added to the grid position to give x1, y1, x2, y2 below.
    const float *off_l = this->_p_prob_3_2;
    const float *off_t = this->_p_prob_3_2 + plane;
    const float *off_r = this->_p_prob_3_2 + 2 * plane;
    const float *off_b = this->_p_prob_3_2 + 3 * plane;

    // Candidate boxes in search-crop coordinates, plus their widths and heights.
    std::vector<float> pred_x1(plane), pred_y1(plane), pred_x2(plane), pred_y2(plane);
    std::vector<float> w(plane), h(plane);
    for (int i = 0; i < plane; ++i)
    {
      pred_x1[i] = this->grid_to_search_x[i] - off_l[i];
      pred_y1[i] = this->grid_to_search_y[i] - off_t[i];
      pred_x2[i] = this->grid_to_search_x[i] + off_r[i];
      pred_y2[i] = this->grid_to_search_y[i] + off_b[i];
      w[i] = pred_x2[i] - pred_x1[i];
      h[i] = pred_y2[i] - pred_y1[i];
    }

    // Size / aspect-ratio change penalty.
    float sz_wh = sz_whFun(target_sz);
    std::vector<float> s_c = sz_change_fun(w, h, sz_wh);
    std::vector<float> r_c = ratio_change_fun(w, h, target_sz);

    std::vector<float> penalty(plane, 0);
    for (int i = 0; i < plane; i++)
    {
      penalty[i] = std::exp(-1 * (s_c[i] * r_c[i] - 1) * cfg.penalty_k);
    }

    // Blend with the window penalty and take the arg-max (first maximum wins).
    int best = 0;
    float maxScore = 0;
    for (int i = 0; i < plane; i++)
    {
      const float pscore = (penalty[i] * cls_score_sigmoid[i]) * (1 - cfg.window_influence) + this->window[i] * cfg.window_influence;
      if (pscore > maxScore)
      {
        maxScore = pscore;
        best = i;
      }
    }

    // Selected box, in search-crop coordinates.
    float pred_x1_real = pred_x1[best];
    float pred_y1_real = pred_y1[best];
    float pred_x2_real = pred_x2[best];
    float pred_y2_real = pred_y2[best];

    float pred_xs = (pred_x1_real + pred_x2_real) / 2;
    float pred_ys = (pred_y1_real + pred_y2_real) / 2;
    float pred_w = pred_x2_real - pred_x1_real;
    float pred_h = pred_y2_real - pred_y1_real;

    // Offset of the box centre from the crop centre, converted back to image scale.
    float diff_xs = pred_xs - cfg.instance_size / 2;
    float diff_ys = pred_ys - cfg.instance_size / 2;

    diff_xs /= scale_z;
    diff_ys /= scale_z;
    pred_w /= scale_z;
    pred_h /= scale_z;

    target_sz.x = target_sz.x / scale_z;
    target_sz.y = target_sz.y / scale_z;

    // Size learning rate.
    float lr = penalty[best] * cls_score_sigmoid[best] * cfg.lr;

    auto res_xs = float(target_pos.x + diff_xs);
    auto res_ys = float(target_pos.y + diff_ys);
    float res_w = pred_w * lr + (1 - lr) * target_sz.x;
    float res_h = pred_h * lr + (1 - lr) * target_sz.y;

    target_pos.x = int(res_xs);
    target_pos.y = int(res_ys);

    // NOTE(review): lr is applied twice here (res_w/res_h are already lr-blends) — confirm intent.
    target_sz.x = target_sz.x * (1 - lr) + lr * res_w;
    target_sz.y = target_sz.y * (1 - lr) + lr * res_h;

    cls_score_max = cls_score_sigmoid[best];
  }

  /// Tracks the target into a new frame.
  ///
  /// Crops a context-padded search window around the last known position and resizes it to
  /// cfg.instance_size. It then runs update(), clamps the result to the image, stores it back
  /// in state and reports it as a box.
  ///
  /// @param im           the new frame
  /// @param output_bbox_ receives the tracked box (top-left x, top-left y, width, height)
  void SotTrtDetectorCUDAImpl::cudaTrackImpl(cv::Mat im, cv::Rect &output_bbox_)
  {
    cv::Point pos = this->state.target_pos;
    cv::Point2f size = this->state.target_sz;

    // Context-padded exemplar side and the factor mapping it onto cfg.exemplar_size.
    const auto context = cfg.context_amount * (size.x + size.y);
    const float wc_z = size.x + context;
    const float hc_z = size.y + context;
    const float s_z = sqrt(wc_z * hc_z);
    const float scale_z = cfg.exemplar_size / s_z;

    // The search window adds d_search crop pixels of margin on each side (in image pixels: / scale_z).
    const float d_search = (cfg.instance_size - cfg.exemplar_size) / 2;
    const float s_x = s_z + 2 * (d_search / scale_z);

    cv::Mat x_crop = get_subwindow_tracking(im, pos, cfg.instance_size, int(s_x), state.channel_ave);

    // update() expects the size in search-crop scale.
    size.x *= scale_z;
    size.y *= scale_z;

    float cls_score_max = 0.f;
    this->update(x_crop, pos, size, scale_z, cls_score_max);

    // Keep the centre inside [0, image dimension] and the size inside [10, image dimension].
    pos.x = std::max(0, std::min(state.im_w, pos.x));
    pos.y = std::max(0, std::min(state.im_h, pos.y));
    size.x = float(std::max(10, std::min(state.im_w, int(size.x))));
    size.y = float(std::max(10, std::min(state.im_h, int(size.y))));

    state.target_pos = pos;
    state.target_sz = size;

    output_bbox_ = cv::Rect(int(pos.x - size.x / 2.f), int(pos.y - size.y / 2.f), int(size.x), int(size.y));
  }

  // Builds the 2-D Hann window used as the window penalty in update(): a row-major
  // score_size x score_size grid whose entry (r, c) is hann[r] * hann[c].
  void SotTrtDetectorCUDAImpl::create_window()
  {
    const int n = cfg.score_size;

    // 1-D Hann taper: 0.5 - 0.5 * cos(2*pi*k / (n - 1)).
    std::vector<float> taper;
    taper.reserve(n);
    for (int k = 0; k < n; ++k)
    {
      taper.push_back(0.5f - 0.5f * std::cos(2 * 3.1415926535898f * k / (n - 1)));
    }

    // Outer product, filled row by row.
    this->window.resize(n * n, 0);
    int idx = 0;
    for (int r = 0; r < n; ++r)
    {
      for (int c = 0; c < n; ++c)
      {
        this->window[idx++] = taper[r] * taper[c];
      }
    }
  }

  /// Computes the search-crop position of every score-map cell.
  ///
  /// Cell (i, j) (row-major, i = row) maps to (j * total_stride, i * total_stride) in the
  /// search crop. update() builds box corners around these positions and then subtracts
  /// cfg.instance_size / 2 from the box centre to obtain the offset from the crop centre.
  ///
  /// The grid is deliberately NOT shifted by cfg.instance_size / 2 here. Doing so would make
  /// the two offsets cancel, so every cell's centre offset would come out as j * total_stride
  /// (always >= 0). That biases each prediction towards the bottom-right.
  void SotTrtDetectorCUDAImpl::create_grids()
  {
    const int sz = cfg.score_size; // 16x16

    this->grid_to_search_x.resize(sz * sz, 0);
    this->grid_to_search_y.resize(sz * sz, 0);

    for (int i = 0; i < sz; i++)
    {
      for (int j = 0; j < sz; j++)
      {
        this->grid_to_search_x[i * sz + j] = j * cfg.total_stride;
        this->grid_to_search_y[i * sz + j] = i * cfg.total_stride;
      }
    }
  }

  /// Crops a square window of side original_sz centred on pos and resizes it to model_sz.
  ///
  /// Parts of the window that fall outside the image are filled with channel_ave. Only the
  /// in-image part of the window is copied and padded; the whole frame is never padded.
  ///
  /// @param im          source image
  /// @param pos         window centre in image coordinates
  /// @param model_sz    output side length (the network input size)
  /// @param original_sz window side length in image pixels
  /// @param channel_ave fill colour for out-of-image pixels
  /// @return model_sz x model_sz patch of the same type as im
  cv::Mat SotTrtDetectorCUDAImpl::get_subwindow_tracking(cv::Mat im, cv::Point2f pos, int model_sz, int original_sz, cv::Scalar channel_ave)
  {
    const float c = (float)(original_sz + 1) / 2;
    const int context_xmin = std::round(pos.x - c);
    const int context_xmax = context_xmin + original_sz - 1;
    const int context_ymin = std::round(pos.y - c);
    const int context_ymax = context_ymin + original_sz - 1;

    // How far the window sticks out past each image border.
    const int left_pad = std::max(0, -context_xmin);
    const int top_pad = std::max(0, -context_ymin);
    const int right_pad = std::max(0, context_xmax - im.cols + 1);
    const int bottom_pad = std::max(0, context_ymax - im.rows + 1);

    const cv::Rect window(context_xmin, context_ymin, original_sz, original_sz);
    cv::Mat im_patch_original;

    if (top_pad > 0 || left_pad > 0 || right_pad > 0 || bottom_pad > 0)
    {
      const cv::Rect inside = window & cv::Rect(0, 0, im.cols, im.rows);
      if (inside.empty())
      {
        // The window lies entirely outside the image: every pixel is fill colour.
        im_patch_original = cv::Mat(original_sz, original_sz, im.type(), channel_ave);
      }
      else
      {
        // Pad only the visible part; the pads restore the full original_sz x original_sz
        // window. BORDER_ISOLATED stops OpenCV from borrowing pixels outside the ROI.
        cv::copyMakeBorder(im(inside), im_patch_original, top_pad, bottom_pad, left_pad, right_pad,
                           cv::BORDER_CONSTANT | cv::BORDER_ISOLATED, channel_ave);
      }
    }
    else
    {
      im_patch_original = im(window);
    }

    cv::Mat im_patch;
    cv::resize(im_patch_original, im_patch, cv::Size(model_sz, model_sz));
    return im_patch;
  }

}